# -*- coding: utf-8 -*-
"""
Copyright 2021 NXP
All rights reserved.

SPDX-License-Identifier: BSD-3-Clause

author: Kaleb Belete
"""

import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transform
import torchvision.datasets as dataset
from torch.utils.data.dataloader import DataLoader as Load
from bokeh.plotting import figure
from bokeh.io import show
from bokeh.models import ColumnDataSource, Range1d, LinearAxis

# Output (ONNX model) and input (CIFAR-10 download) directories, relative to the
# working directory. os.path.join picks the platform's separator so the paths
# work on Windows and POSIX alike; the trailing '' component keeps a trailing
# separator so a filename can still be appended with plain '+'.
model_save_path = os.path.join('..', 'Pytorch Models', '')
data_path = os.path.join('..', 'dataset', '')

# Training Parameters
epoch_Tot = 1           # number of full passes over the training set
batch_size = 16         # images per mini-batch (one optimizer step each)
learning_rate = .001    # Adam learning rate


# Per-channel (R, G, B) mean and std used by both pipelines. ToTensor scales
# pixel values from [0, 255] to [0, 1]; Normalize with mean 0.5 / std 0.5
# then maps [0, 1] onto [-1, 1].
_norm_stats = (.5, .5, .5)

# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
trans = transform.Compose([
    transform.ToTensor(),
    transform.Normalize(_norm_stats, _norm_stats),
])

# Training pipeline: pad each border by 4 px and take a random 32x32 crop,
# flip horizontally with 50% probability, then convert and normalize to
# [-1, 1] exactly like the evaluation pipeline.
train_trans = transform.Compose([
    transform.RandomCrop(size=[32, 32], padding=4),
    transform.RandomHorizontalFlip(),
    transform.ToTensor(),
    transform.Normalize(_norm_stats, _norm_stats),
])
  
# CIFAR-10 training split (augmented), downloaded into data_path if missing.
train_data = dataset.CIFAR10(
    root=data_path,
    train=True,
    transform=train_trans,
    download=True,
)
# CIFAR-10 test split (no augmentation). It is opened without download=True,
# so it relies on the files fetched by the training-split call above.
test_data = dataset.CIFAR10(
    root=data_path,
    train=False,
    transform=trans,
)

# Wrap both splits in mini-batch loaders; only the training order is shuffled.
train_loader, test_loader = (
    Load(dataset=train_data, batch_size=batch_size, shuffle=True),
    Load(dataset=test_data, batch_size=batch_size, shuffle=False),
)

channel = 32  # default number of first-stage feature maps; later stages use 2x


class NeuralNet(nn.Module):
    """Three-stage convolutional classifier for 32x32 images.

    Stages and output sizes for an ``in_channels`` x 32 x 32 input:
        layer1: 2 x [Conv3x3 -> BatchNorm -> ReLU] with ``channels`` maps,
                then MaxPool 2x2                 -> channels   x 16 x 16
        layer2: Conv3x3 -> BatchNorm -> ReLU with 2*channels maps,
                then MaxPool 2x2                 -> 2*channels x 8 x 8
        layer3: Conv3x3 -> BatchNorm -> ReLU with 2*channels maps,
                then MaxPool 2x2                 -> 2*channels x 4 x 4
        fcLayer: Linear(4*4*2*channels -> num_classes), returning raw logits.

    Args:
        channels: feature maps in the first stage. ``None`` (the default)
            uses the module-level ``channel`` value at construction time.
        num_classes: length of the output logit vector (10 for CIFAR-10).
        in_channels: channels of the input image (3 for RGB).

    Submodule names and ordering (layer1..layer3, fcLayer) are fixed so that
    state_dict keys stay compatible with previously saved models.
    """

    def __init__(self, channels=None, num_classes=10, in_channels=3):
        super(NeuralNet, self).__init__()
        if channels is None:
            channels = channel
        wide = channels * 2
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, channels, kernel_size=3, padding=1),  # padding=1 keeps 32x32
            nn.BatchNorm2d(channels),       # normalize each feature map over the batch
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))  # 32x32 -> 16x16
        self.layer2 = nn.Sequential(
            nn.Conv2d(channels, wide, kernel_size=3, padding=1),
            nn.BatchNorm2d(wide),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))  # 16x16 -> 8x8
        self.layer3 = nn.Sequential(
            nn.Conv2d(wide, wide, kernel_size=3, padding=1),
            nn.BatchNorm2d(wide),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))  # 8x8 -> 4x4
        # Classifier over the flattened 4x4 feature maps; this size assumes a
        # 32x32 input image.
        self.fcLayer = nn.Linear(4 * 4 * wide, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, in_channels, 32, 32); other spatial
        sizes will not match the fully connected layer's input size.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        # Flatten the (batch, C, 4, 4) feature maps (not weights) into
        # (batch, C*4*4) vectors for the linear classifier.
        out = out.reshape(out.size(0), -1)
        return self.fcLayer(out)


# Network, its Adam optimizer over all trainable parameters, and the
# classification loss applied to the network's raw logits.
model = NeuralNet()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# Training history: batches per epoch, plus one loss value and one accuracy
# fraction per training batch (used later for progress output and plotting).
total_step = len(train_loader)
loss_list, accuracy_track = [], []


# Training: one optimizer step per mini-batch, recording per-batch loss and
# accuracy in loss_list / accuracy_track.
model.train()  # BatchNorm uses per-batch statistics while training
for epoch in range(epoch_Tot):
    # Epoch wall-clock time; perf_counter is the clock intended for measuring
    # intervals (monotonic, high resolution), unlike time.time().
    t0 = time.perf_counter()

    for batch_idx, (image, label) in enumerate(train_loader):
        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()

        # Forward pass: compute logits and the batch loss.
        output = model(image)
        loss = loss_fn(output, label)
        batch_loss = loss.item()
        loss_list.append(batch_loss)

        # Backward pass: compute gradients and update the weights.
        loss.backward()
        optimizer.step()

        # Accuracy on this batch only. argmax over the class dimension gives
        # the predicted label; detach() keeps this out of the autograd graph.
        total = label.size(0)
        predicted = output.detach().argmax(dim=1)
        correct = (predicted == label).sum().item()
        accuracy_track.append(correct / total)

        # Report progress every 100 batches.
        if (batch_idx + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}],  Loss: {:.4f}, Accuracy: {:.2f}%'
                  .format(epoch + 1, epoch_Tot, batch_idx + 1, total_step,
                          batch_loss, (correct / total) * 100))
    print('Epoch [{}/{}], {:.1f} seconds'
          .format(epoch + 1, epoch_Tot, time.perf_counter() - t0))


# Evaluation mode: affects BatchNorm, which switches from per-batch statistics
# to the running statistics accumulated during training.
model.eval()
correct = 0
total = 0
test_loss = 0

# Testing (no gradients are needed, which saves memory and time).
with torch.no_grad():
    for image, label in test_loader:
        output = model(image)
        batch_n = label.size(0)
        predicted = output.argmax(dim=1)
        total += batch_n
        correct += (predicted == label).sum().item()
        # loss_fn returns the mean over the batch; weight it by the batch size
        # so the final value is a true per-sample average even when the last
        # batch is smaller than batch_size.
        test_loss += loss_fn(output, label).item() * batch_n
    test_loss /= total

    print('Test Accuracy: {}/{} ({:.2f}%) Avg Loss: {:.4f}' .format(correct,
          total, (correct / total) * 100, test_loss))

# Export the trained model to ONNX. Export in inference mode so BatchNorm
# uses its running statistics.
model.eval()
# Dummy input for tracing, laid out as (batch, channels (RGB), height, width);
# only its shape matters, not its values.
input_shape = torch.randn(1, 3, 32, 32)
# The export writes the file directly and would fail if the target directory
# did not exist yet, so create it first.
os.makedirs(model_save_path, exist_ok=True)
# Name the graph's input and output nodes for downstream tools.
torch.onnx.export(model, input_shape, model_save_path + 'Cifar.onnx',
                  input_names=['input'], output_names=['output'])

# Plot per-batch training loss (left axis) and accuracy in percent (right
# axis) against fractional epoch.
epoch_list = np.linspace(0, epoch_Tot, len(loss_list))
accuracy_pct = [acc * 100 for acc in accuracy_track]
source = ColumnDataSource(data={'Epoch': epoch_list, 'Training Loss': loss_list,
                                'Accuracy': accuracy_pct})
# NOTE: loss values above 1 (typically the early batches) fall outside this
# y_range and are clipped from view.
plot = figure(title="Pytorch Training Results", width=900, x_axis_label='Epoch',
              y_axis_label='Loss', y_range=(0, 1))
plot.xaxis.ticker = list(range(1, epoch_Tot + 1))
# Secondary 0-100 range for accuracy, shown as an axis on the right.
plot.extra_y_ranges = {'Accuracy': Range1d(0, 100)}
plot.add_layout(LinearAxis(y_range_name='Accuracy',
                           axis_label='Accuracy(%)'), 'right')
plot.line(x='Epoch', y='Training Loss', source=source, color='blue',
          legend_label='Training loss', line_width=1)
# Bind the accuracy line to the 'Accuracy' range so the right-hand axis
# actually scales it.
plot.line(x='Epoch', y='Accuracy', source=source, y_range_name='Accuracy',
          color='red', legend_label='Accuracy', line_width=1)
show(plot)
